# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, missing-docstring
"""Unittests for tvm.script.ir_builder.tir"""
import numpy as np
import pytest
import tvm
import tvm.testing
from tvm import tir
from tvm.ir.base import assert_structural_equal
from tvm.runtime import ndarray
from tvm.script.ir_builder import IRBuilder
from tvm.script.ir_builder import tir as T


def test_ir_builder_tir_primfunc_base():
    """A parameterless prim_func whose body is just ``T.evaluate(0)``."""
    with IRBuilder() as builder:
        with T.prim_func():
            T.evaluate(0)
    func_actual = builder.get()

    # Reference function assembled directly from TIR nodes: no params,
    # no return type, no buffer map and no attributes.
    func_expected = tir.PrimFunc(
        params=[],
        body=tir.Evaluate(0),
        ret_type=None,
        buffer_map=None,
        attrs=None,
    )

    assert_structural_equal(func_actual, func_expected, map_free_vars=True)


def test_ir_builder_tir_primfunc_complete():
    """A prim_func that uses every signature-level builder.

    Covers handle, scalar and buffer-typed ``T.arg`` calls, plus
    ``T.match_buffer``, ``T.func_attr`` and ``T.func_ret``.
    """
    with IRBuilder() as ib:
        with T.prim_func():
            T.arg("a", T.handle())
            T.arg("b", T.int64())
            T.arg("c", T.Buffer((128, 128), "float32"))
            d = T.arg("d", T.handle())
            T.arg("e", T.Buffer((1024,), "int8"))
            T.func_attr({"key": "value"})
            T.func_ret(tvm.ir.PrimType("int64"))
            # Only the side effect matters here: binding handle `d` to a
            # (64, 64) int64 buffer. The returned Buffer is not needed.
            T.match_buffer(d, (64, 64), "int64")
            T.evaluate(0)

    # the prim_func generated by IRBuilder
    prim_func_actual = ib.get()

    # The expected prim_func. The buffer-typed args ("c", "e") and the matched
    # handle "d" each appear as a handle param mapped to a buffer in buffer_map.
    c_handle = tir.Var("c_handle", "handle")
    c_buffer = tir.decl_buffer((128, 128), "float32", name="c")
    d_handle = tir.Var("d", "handle")
    d_buffer = tir.decl_buffer((64, 64), "int64", name="d")
    e_handle = tir.Var("e_handle", "handle")
    e_buffer = tir.decl_buffer((1024,), "int8", name="e")
    prim_func_expected = tir.PrimFunc(
        params=[
            tir.Var("a", "handle"),
            tir.Var("b", "int64"),
            c_handle,
            d_handle,
            e_handle,
        ],
        body=tir.Evaluate(0),
        ret_type=tvm.ir.PrimType("int64"),
        buffer_map={c_handle: c_buffer, d_handle: d_buffer, e_handle: e_buffer},
        attrs=tvm.ir.make_node("DictAttrs", key="value"),
    )

    # Check if the generated ir is expected
    assert_structural_equal(prim_func_actual, prim_func_expected, map_free_vars=True)


def test_ir_builder_tir_block_base():
    """A bare ``T.block`` yields a BlockRealize with no iterators and a true predicate."""
    with IRBuilder() as builder:
        with T.block("block"):
            T.evaluate(0)
    realize_actual = builder.get()

    # The expected block carries the "tir.script_parsing_detect_access"
    # annotation (int64 value 3) that the builder attaches in this case.
    inner_block = tir.Block(
        iter_vars=[],
        reads=[],
        writes=[],
        name_hint="block",
        body=tir.Evaluate(0),
        alloc_buffers=None,
        match_buffers=None,
        annotations={"tir.script_parsing_detect_access": tir.IntImm("int64", 3)},
    )
    realize_expected = tir.BlockRealize(
        iter_values=[],
        predicate=True,
        block=inner_block,
    )

    assert_structural_equal(realize_actual, realize_expected, map_free_vars=True)


def test_ir_builder_tir_block_complete():
    """A block that uses every block-level builder.

    Covers ``T.where``, ``T.reads``, ``T.writes``, ``T.block_attr``,
    ``T.alloc_buffer``, ``T.match_buffer`` and a ``T.axis`` binding.
    """
    with IRBuilder() as ib:
        a = T.int64()
        b = T.Buffer((128, 128), "float32")
        c = T.Buffer((128, 128), "float32")
        d = T.int32()
        e = T.Buffer((128, 128), "float32")
        f = T.int32()
        with T.block("block"):
            T.where(a > 1)
            T.reads(b[0:16, 0:16])
            T.writes(c[d:128, d:128])
            T.block_attr({"key": "value"})
            T.alloc_buffer((128, 128), "float32")
            T.match_buffer(e[0:32, 0:32], (32, 32), "float32")
            T.axis.spatial(128, f)
            T.evaluate(0)

    # the block generated by IRBuilder
    block_realize_actual = ib.get()

    # The expected block. Each var_*/buffer_* mirrors the builder-side object
    # of the same letter, and uses that letter as its name.
    var_a = tir.Var("a", "int64")
    buffer_b = tir.decl_buffer((128, 128), "float32", name="b")
    buffer_c = tir.decl_buffer((128, 128), "float32", name="c")
    var_d = tir.Var("d", "int32")
    buffer_e = tir.decl_buffer((128, 128), "float32", name="e")
    var_f = tir.Var("f", "int32")
    block_expected = tir.Block(
        iter_vars=[tir.IterVar((0, 128), tir.Var("", "int32"), iter_type=tir.IterVar.DataPar)],
        reads=[buffer_b[0:16, 0:16]],
        writes=[buffer_c[var_d:128, var_d:128]],
        name_hint="block",
        body=tir.Evaluate(0),
        alloc_buffers=[tir.decl_buffer((128, 128), "float32")],
        match_buffers=[
            tir.MatchBufferRegion(tir.decl_buffer((32, 32), "float32"), buffer_e[0:32, 0:32])
        ],
        annotations={"key": "value"},
    )
    # The T.where predicate and the T.axis binding value live on the realize.
    block_realize_expected = tir.BlockRealize(
        iter_values=[var_f],
        predicate=var_a > 1,
        block=block_expected,
    )

    # Check if the generated ir is expected
    assert_structural_equal(block_realize_actual, block_realize_expected, map_free_vars=True)


def test_ir_builder_tir_axis():
    """Each ``T.axis`` kind maps to the matching IterVar type, in declaration order."""
    with IRBuilder() as builder:
        a, b, c, d = (T.int32() for _ in range(4))
        with T.block("block"):
            T.axis.spatial(8, a)
            T.axis.reduce(16, b)
            T.axis.scan(32, c)
            T.axis.opaque(64, d)
            T.evaluate(0)
    realize_actual = builder.get()

    # (extent, IterVar type) for each axis, in the order declared above.
    axis_specs = [
        (8, tir.IterVar.DataPar),
        (16, tir.IterVar.CommReduce),
        (32, tir.IterVar.Ordered),
        (64, tir.IterVar.Opaque),
    ]
    bound_values = [tir.Var(letter, "int32") for letter in "abcd"]
    inner_block = tir.Block(
        iter_vars=[
            tir.IterVar((0, extent), tir.Var("", "int32"), iter_type=kind)
            for extent, kind in axis_specs
        ],
        reads=[],
        writes=[],
        name_hint="block",
        body=tir.Evaluate(0),
        annotations={"tir.script_parsing_detect_access": tir.IntImm("int64", 3)},
    )
    realize_expected = tir.BlockRealize(
        iter_values=bound_values,
        predicate=True,
        block=inner_block,
    )

    assert_structural_equal(realize_actual, realize_expected, map_free_vars=True)


def test_ir_builder_tir_for():
    """Nested loops of every ForKind, with a thread-bound loop innermost."""
    with IRBuilder() as ib:
        # The loop variables are never referenced in the body, so the frames
        # are entered without binding them to names.
        with T.serial(128):
            with T.parallel(64):
                with T.vectorized(32):
                    with T.unroll(16):
                        with T.thread_binding(8, thread="threadIdx.x"):
                            T.evaluate(0)

    # the for generated by IRBuilder
    for_actual = ib.get()

    # The expected loop nest, built from the innermost loop outward.
    thread_binding_expected = tir.For(
        loop_var=tir.Var("", "int32"),
        min=0,
        extent=8,
        kind=tir.ForKind.THREAD_BINDING,
        body=tir.Evaluate(0),
        thread_binding=tir.IterVar(
            None, tir.Var("", "int32"), tir.IterVar.ThreadIndex, "threadIdx.x"
        ),
    )
    unroll_expected = tir.For(
        loop_var=tir.Var("", "int32"),
        min=0,
        extent=16,
        kind=tir.ForKind.UNROLLED,
        body=thread_binding_expected,
    )
    vectorized_expected = tir.For(
        loop_var=tir.Var("", "int32"),
        min=0,
        extent=32,
        kind=tir.ForKind.VECTORIZED,
        body=unroll_expected,
    )
    parallel_expected = tir.For(
        loop_var=tir.Var("", "int32"),
        min=0,
        extent=64,
        kind=tir.ForKind.PARALLEL,
        body=vectorized_expected,
    )
    for_expected = tir.For(
        loop_var=tir.Var("", "int32"),
        min=0,
        extent=128,
        kind=tir.ForKind.SERIAL,
        body=parallel_expected,
    )

    # Check if the generated ir is expected
    assert_structural_equal(for_actual, for_expected, map_free_vars=True)


def test_ir_builder_tir_for_uint():
    """A uint32 extent yields a loop whose var and min are uint32 as well."""
    with IRBuilder() as ib:
        # The loop variable is unused in the body, so it is not bound.
        with T.serial(tir.const(128, "uint32")):
            T.evaluate(0)

    # the for generated by IRBuilder
    for_actual = ib.get()

    # the expected for: loop var, min and extent all share the uint32 dtype
    for_expected = tir.For(
        loop_var=tir.Var("", "uint32"),
        min=tir.const(0, "uint32"),
        extent=tir.const(128, "uint32"),
        kind=tir.ForKind.SERIAL,
        body=tir.Evaluate(0),
    )

    # Check if the generated ir is expected
    assert_structural_equal(for_actual, for_expected, map_free_vars=True)


def test_ir_builder_tir_assert():
    """``T.Assert`` produces an AssertStmt that wraps the frame body."""
    with IRBuilder() as builder:
        guard = T.int32() == 0
        with T.Assert(guard, message="a is 0"):
            T.evaluate(0)
    stmt_actual = builder.get()

    # A fresh free var on each side; map_free_vars lets them correspond.
    expected_guard = T.int32() == 0
    expected_message = tir.StringImm("a is 0")
    stmt_expected = tir.AssertStmt(expected_guard, expected_message, tir.Evaluate(0))

    assert_structural_equal(stmt_actual, stmt_expected, map_free_vars=True)


def test_ir_builder_tir_let():
    """``T.LetStmt`` binds a fresh int32 var to the given value around its body."""
    with IRBuilder() as ib:
        # The bound variable is not used in the body, so it is not captured.
        with T.LetStmt(tir.IntImm("int32", 2)):
            T.evaluate(0)
    # the let binding generated by IRBuilder
    let_actual = ib.get()

    # the expected Let statement; its var is free and matched via map_free_vars
    let_expected = tir.LetStmt(T.int32(), tir.IntImm("int32", 2), tir.Evaluate(0))

    # Check if the generated ir is expected
    assert_structural_equal(let_actual, let_expected, map_free_vars=True)


def test_ir_builder_tir_realize():
    """``T.realize`` emits a ``realize_scope`` AttrStmt around a BufferRealize."""
    buffer_a = T.Buffer((128, 128), "float32")
    with IRBuilder() as builder:
        with T.realize(buffer_a[0:128, 0:128], "test_storage_scope", True):
            T.evaluate(0)
    realize_actual = builder.get()

    # The realized region covers both full axes of the buffer.
    full_bounds = [tvm.ir.Range(0, 128), tvm.ir.Range(0, 128)]
    realize_node = tir.BufferRealize(buffer_a, full_bounds, True, tir.Evaluate(0))
    scope_name = tir.StringImm("test_storage_scope")
    realize_expected = tir.AttrStmt(buffer_a, "realize_scope", scope_name, realize_node)

    assert_structural_equal(realize_actual, realize_expected, map_free_vars=True)


def test_ir_builder_tir_thread():
    """``T.launch_thread`` on an env thread emits a ``thread_extent`` AttrStmt."""
    with IRBuilder() as ib:
        with T.prim_func():
            brow = T.env_thread("blockIdx.y")
            with T.launch_thread(brow, 1):
                T.evaluate(0)

    # the prim_func generated by IRBuilder
    ir_actual = ib.get()

    # The expected prim_func. Use the named IterVar type instead of the bare
    # integer 1, consistent with how thread bindings are spelled elsewhere.
    iter_var = tir.IterVar(
        (0, 1), "v", iter_type=tir.IterVar.ThreadIndex, thread_tag="blockIdx.y"
    )
    attr_stmt = tir.AttrStmt(iter_var, "thread_extent", 1, tir.Evaluate(0))
    func = tir.PrimFunc([], attr_stmt)

    # Check if the generated ir is expected
    assert_structural_equal(ir_actual, func, map_free_vars=True)


def test_ir_builder_tir_allocate():
    """``T.allocate`` with a scope yields an Allocate over a scoped pointer var."""
    with IRBuilder() as builder:
        with T.allocate([10], "float32", scope="local"):
            T.evaluate(1)
    alloc_actual = builder.get()

    # The data var is a float32 pointer tagged with the "local" scope.
    pointer_type = tvm.ir.PointerType(tvm.ir.PrimType("float32"), "local")
    data_var = tir.Var("v", pointer_type)
    condition = tir.const(1, "uint1")
    alloc_expected = tir.Allocate(data_var, "float32", [10], condition, tir.Evaluate(1))

    assert_structural_equal(alloc_actual, alloc_expected, map_free_vars=True)


def test_ir_builder_tir_allocate_const():
    """``T.allocate_const`` embeds its data as an NDArray in an AllocateConst."""
    data = [1] * 10
    with IRBuilder() as builder:
        with T.allocate_const(data, "int32", [10]):
            T.evaluate(1)
    const_actual = builder.get()

    pointer_var = tir.Var("v", tvm.ir.PointerType(tvm.ir.PrimType("int32")))
    payload = ndarray.array(np.asarray(data, "int32"))
    const_expected = tir.AllocateConst(
        pointer_var,
        "int32",
        [10],
        payload,
        tir.Evaluate(1),
        annotations={},
    )

    assert_structural_equal(const_actual, const_expected, map_free_vars=True)


def test_ir_builder_tir_while():
    """``T.While`` produces a While node guarded by the given condition."""
    with IRBuilder() as builder:
        loop_cond = T.int32() > 0
        with T.While(loop_cond):
            T.evaluate(0)
    while_actual = builder.get()

    # The var name differs from the builder's; map_free_vars reconciles them.
    while_expected = tir.While(tir.Var("x", "int32") > 0, tir.Evaluate(0))

    assert_structural_equal(while_actual, while_expected, map_free_vars=True)


def test_ir_builder_tir_if_then_else():
    """``T.If`` with ``T.Then``/``T.Else`` frames yields an IfThenElse node."""
    with IRBuilder() as builder:
        branch_cond = T.int32() < 12
        with T.If(branch_cond):
            with T.Then():
                T.evaluate(T.int32(0))
            with T.Else():
                T.evaluate(T.int32(1))
    if_actual = builder.get()

    then_stmt = tir.Evaluate(tir.IntImm("int32", 0))
    else_stmt = tir.Evaluate(tir.IntImm("int32", 1))
    if_expected = tir.IfThenElse(tir.Var("c", "int32") < 12, then_stmt, else_stmt)

    assert_structural_equal(if_actual, if_expected, map_free_vars=True)


def test_ir_builder_tir_buffer_store():
    """``T.buffer_store`` emits a BufferStore with the given value and indices."""
    target = T.Buffer((10, 10), "float32")
    col = T.int32()
    with IRBuilder() as builder:
        T.buffer_store(target, 0.1, [0, col])
    store_actual = builder.get()

    store_expected = tir.BufferStore(target, 0.1, [0, col])

    assert_structural_equal(store_actual, store_expected, map_free_vars=True)


def test_ir_builder_tir_buffer_store_scalable_vec():
    """A vscale-sized broadcast stored through a vscale-sized ramp index."""
    target = T.Buffer((30,), "float32")
    vec_value = T.broadcast(0.11, 4 * tvm.tir.vscale())
    vec_index = T.ramp(0, 1, 4 * tvm.tir.vscale())

    with IRBuilder() as builder:
        T.buffer_store(target, vec_value, [vec_index])
    store_actual = builder.get()

    store_expected = tir.BufferStore(target, vec_value, [vec_index])

    assert_structural_equal(store_actual, store_expected, map_free_vars=True)


def test_ir_builder_tir_buffer_store_predicate():
    """``T.buffer_store`` forwards its optional predicate into the BufferStore."""
    target = T.Buffer((30,), "float32")
    vec_value = T.broadcast(0.11, T.vscale() * 4)
    vec_index = T.ramp(0, 1, T.vscale() * 4)
    store_mask = T.broadcast(T.bool(True), T.vscale() * 4)

    with IRBuilder() as builder:
        T.buffer_store(target, vec_value, [vec_index], store_mask)

    store_expected = tir.BufferStore(target, vec_value, [vec_index], store_mask)
    assert_structural_equal(builder.get(), store_expected, map_free_vars=True)


def test_ir_builder_tir_prefetch():
    """``T.prefetch`` with an empty region list emits a matching Prefetch stmt."""
    with IRBuilder() as builder:
        target = T.Buffer((128, 128), "float32")
        T.prefetch(target, [])
    prefetch_actual = builder.get()

    prefetch_expected = tir.Prefetch(target, [])

    assert_structural_equal(prefetch_actual, prefetch_expected, map_free_vars=True)


def test_ir_builder_tir_evaluate():
    """A top-level ``T.evaluate`` with no enclosing frame builds a bare Evaluate."""
    expected_stmt = tir.Evaluate(0)

    with IRBuilder() as builder:
        T.evaluate(0)

    assert_structural_equal(builder.get(), expected_stmt, map_free_vars=True)


def test_ir_builder_tir_decl_buffer():
    """``T.decl_buffer`` expands to an Allocate of the buffer's data around a DeclBuffer."""
    with IRBuilder() as builder:
        with T.decl_buffer([128, 128], "float32"):
            T.evaluate(0)
    decl_actual = builder.get()

    # The Allocate's pointer var is the data var of the declared buffer.
    ref_buffer = T.Buffer((128, 128), "float32")
    decl_node = tir.DeclBuffer(ref_buffer, tir.Evaluate(0))
    decl_expected = tir.Allocate(
        ref_buffer.data,
        "float32",
        (128, 128),
        tir.IntImm("bool", True),
        decl_node,
    )

    assert_structural_equal(decl_actual, decl_expected, map_free_vars=True)


def test_ir_builder_tir_inline():
    """The ``.value`` of each ``T.meta_var`` (direct or unpacked) folds into a constant."""
    with IRBuilder() as builder:
        m, n = T.meta_var(1), T.meta_var(2)
        a, b = T.meta_var([3, 4])
        total = m.value + n.value + a.value + b.value
        T.evaluate(total)
    eval_actual = builder.get()

    # 1 + 2 + 3 + 4
    eval_expected = tir.Evaluate(10)

    assert_structural_equal(eval_actual, eval_expected, map_free_vars=True)


if __name__ == "__main__":
    # Entry point for running this file directly as a script; delegates to
    # tvm.testing.main() (presumably a pytest wrapper for this module — confirm).
    tvm.testing.main()
